package cn.dagteam.springboot.mongodb.starter.listener;

import cn.dagteam.springboot.mongodb.starter.Cascade;
import cn.dagteam.springboot.mongodb.starter.Cascade.CascadeType;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.data.mongodb.core.mapping.event.BeforeDeleteEvent;
import org.springframework.util.ReflectionUtils;

import java.beans.IntrospectionException;
import java.beans.PropertyDescriptor;
import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.Collection;
import java.util.Objects;

/**
 * Field callback that cascades a MongoDB delete to objects referenced by the entity being deleted.
 *
 * <p>For every field passed to {@link #doWith(Field)} (for example via
 * {@link ReflectionUtils#doWithFields(Class, ReflectionUtils.FieldCallback)}), the callback checks
 * for a {@link Cascade} annotation whose types include {@link CascadeType#ALL} or
 * {@link CascadeType#DELETE}. For such fields it reloads the entity that is about to be deleted
 * (looked up by the {@code _id} entry of the {@link BeforeDeleteEvent} document), reads the field
 * through its JavaBean getter and removes the referenced object(s) with
 * {@link MongoTemplate#remove(Object)}. Collections and object arrays are removed element by
 * element; any other value is removed as a single object.
 *
 * <p>Deletes whose event document carries no {@code _id} entry are not cascaded.
 */
@Slf4j
@AllArgsConstructor
public class CascadeDeleteCallback implements ReflectionUtils.FieldCallback {

    /** The pending delete; its document supplies the {@code _id} of the entity being removed. */
    private final BeforeDeleteEvent<Object> event;

    /** Template used both to reload the entity and to remove the cascaded references. */
    private final MongoTemplate mongoTemplate;

    /**
     * Cascades the delete through {@code field} if it carries a delete-enabled {@link Cascade}.
     *
     * @param field a field of the entity type being deleted
     * @throws IllegalArgumentException declared by the {@link ReflectionUtils.FieldCallback} contract
     * @throws RuntimeException if a cascading field has no JavaBean getter/setter pair on the
     *     entity class
     */
    @Override
    public void doWith(Field field) throws IllegalArgumentException {
        // Check the annotation first: it is cheap, it avoids an NPE for un-annotated fields, and it
        // skips the database lookup and bean introspection for fields that do not cascade.
        Cascade cascade = field.getAnnotation(Cascade.class);
        if (cascade == null || !cascadesDelete(cascade)) {
            return;
        }
        Object source = loadSource();
        if (source == null) {
            return;
        }
        Object value = readProperty(field, source);
        if (value == null) {
            return;
        }
        if (value instanceof Collection) {
            log.debug("级联删除集合：{}", field.getName());
            ((Collection<?>) value).stream()
                    // skip null elements rather than passing them to remove(Object)
                    .filter(Objects::nonNull)
                    .forEach(item -> mongoTemplate.remove(item));
        } else if (value instanceof Object[]) {
            // Arrays are treated like collections instead of being removed as a single "entity".
            log.debug("级联删除集合：{}", field.getName());
            Arrays.stream((Object[]) value)
                    .filter(Objects::nonNull)
                    .forEach(item -> mongoTemplate.remove(item));
        } else {
            log.debug("级联删除对象：{}", field.getName());
            mongoTemplate.remove(value);
        }
    }

    /**
     * Tells whether the annotation enables delete propagation.
     *
     * @param cascade the annotation found on the field
     * @return {@code true} if its types contain {@link CascadeType#ALL} or {@link CascadeType#DELETE}
     */
    private static boolean cascadesDelete(Cascade cascade) {
        return Arrays.stream(cascade.value())
                .anyMatch(type -> type == CascadeType.ALL || type == CascadeType.DELETE);
    }

    /**
     * Reloads the entity that is about to be deleted.
     *
     * @return the stored entity, or {@code null} if the event has no type, no document or no
     *     {@code _id} entry, or if no matching document is found
     */
    private Object loadSource() {
        Class<?> sourceClass = event.getType();
        if (sourceClass == null || event.getDocument() == null) {
            return null;
        }
        Object id = event.getDocument().get("_id");
        if (id == null) {
            // Delete issued with a query that has no _id entry: there is no single entity to load.
            // NOTE(review): an operator-style _id (e.g. {$in: [...]}) is passed through unchanged
            // and will likely not match via findById — confirm whether such deletes occur.
            return null;
        }
        String collectionName = event.getCollectionName();
        if (collectionName == null) {
            // Let the template derive the collection from the entity type.
            return mongoTemplate.findById(id, sourceClass);
        }
        return mongoTemplate.findById(id, sourceClass, collectionName);
    }

    /**
     * Reads {@code field} from {@code source} through its JavaBean getter.
     *
     * @param field  the field whose property is read
     * @param source the reloaded entity
     * @return the property value, possibly {@code null}
     * @throws RuntimeException if the property has no getter/setter pair on the entity class
     */
    private static Object readProperty(Field field, Object source) {
        try {
            PropertyDescriptor pd = new PropertyDescriptor(field.getName(), source.getClass());
            return ReflectionUtils.invokeMethod(pd.getReadMethod(), source);
        } catch (IntrospectionException e) {
            throw new RuntimeException("不是javabean的对象", e);
        }
    }
}
